from typing import List
import numpy as np

from src.geometry import LineSegment, Distance

from .node import Node
from .rrt_map import RRTMap
from .rrt_parameter import RRTParameter
from .rrt_planner import RRTPlanner


class RRTStarPlanner(RRTPlanner):
    def __init__(self, rrt_map: RRTMap, rrt_parameter: RRTParameter) -> None:
        super().__init__(rrt_map, rrt_parameter)

    def find_near_nodes(self, new_node: Node) -> List[int]:
        n_node = len(self.nodes)
        r = 50.0 * np.sqrt((np.log(n_node) / n_node))
        d_list = [Distance.point_to_point(node.get_point(), new_node.get_point()) for node in self.nodes]
        near_inds = [d_list.index(d) for d in d_list if d <= r ** 2]
        return near_inds

    def choose_parent(self, new_node: Node, near_inds: List[int]) -> Node:
        if len(near_inds) == 0:
            return new_node

        d_list = []
        for i in near_inds:
            if not self.check_collision(LineSegment(self.nodes[i].get_point(), new_node.get_point())):
                d = Distance.point_to_point(self.nodes[i].get_point(), new_node.get_point())
                d_list.append(self.nodes[i].cost + d)
            else:
                d_list.append(float("inf"))

        min_cost = min(d_list)
        min_ind = near_inds[d_list.index(min_cost)]

        if min_cost == float('inf'):
            return new_node

        new_node.set_cost(min_cost)
        new_node.set_parent(min_ind)

        return new_node

    def rewire(self, new_node: Node, near_inds: List[int]) -> None:
        n_node = len(self.nodes)
        for i in near_inds:
            near_node = self.nodes[i]

            d = Distance.point_to_point(near_node.get_point(), new_node.get_point())

            s_cost = new_node.cost + d

            if near_node.cost > s_cost:
                if not self.check_collision(LineSegment(near_node.get_point(), new_node.get_point())):
                    near_node.parent = n_node - 1
                    near_node.cost = s_cost

    def add_node(self, node: Node) -> None:
        near_inds = self.find_near_nodes(node)
        node = self.choose_parent(node, near_inds)

        self.nodes.append(node)
        self.rewire(node, near_inds)

    def get_path_and_length(self) -> bool:
        path = self.get_final_course()
        path_length = self.get_path_length()

        if self.last_path_length > path_length:
            self.path = path
            self.last_path_length = path_length

        return False

    def set_goal(self) -> None:
        near_inds = self.find_near_nodes(self.goal)
        self.goal = self.choose_parent(self.goal, near_inds)
